{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.0.0-alpha0\n",
      "sys.version_info(major=3, minor=7, micro=3, releaselevel='final', serial=0)\n",
      "matplotlib 3.0.3\n",
      "numpy 1.16.2\n",
      "pandas 0.24.2\n",
      "sklearn 0.20.3\n",
      "tensorflow 2.0.0-alpha0\n",
      "tensorflow.python.keras.api._v2.keras 2.2.4-tf\n"
     ]
    }
   ],
   "source": [
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import sklearn\n",
    "import pandas as pd\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "import tensorflow as tf\n",
    "\n",
    "from tensorflow import keras\n",
    "\n",
    "# Report the TensorFlow and interpreter versions, then each library's version.\n",
    "print(tf.__version__)\n",
    "print(sys.version_info)\n",
    "libraries = (mpl, np, pd, sklearn, tf, keras)\n",
    "for lib in libraries:\n",
    "    print(lib.__name__, lib.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import fetch_california_housing\n",
    "\n",
    "# Load the California housing dataset; later cells read housing.data,\n",
    "# housing.target and housing.feature_names from this object.\n",
    "housing = fetch_california_housing()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(11610, 8) (11610,)\n",
      "(3870, 8) (3870,)\n",
      "(5160, 8) (5160,)\n"
     ]
    }
   ],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Two-stage split: hold out a test set first, then split the remainder\n",
    "# into training and validation sets. Fixed seeds keep the splits repeatable.\n",
    "x_train_all, x_test, y_train_all, y_test = train_test_split(\n",
    "    housing.data, housing.target, random_state=7)\n",
    "x_train, x_valid, y_train, y_valid = train_test_split(\n",
    "    x_train_all, y_train_all, random_state=11)\n",
    "for features, targets in ((x_train, y_train),\n",
    "                          (x_valid, y_valid),\n",
    "                          (x_test, y_test)):\n",
    "    print(features.shape, targets.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "# Learn the scaling statistics from the training split only, then apply\n",
    "# that same transform to every split.\n",
    "scaler = StandardScaler().fit(x_train)\n",
    "x_train_scaled = scaler.transform(x_train)\n",
    "x_valid_scaled = scaler.transform(x_valid)\n",
    "x_test_scaled = scaler.transform(x_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_dir = \"generate_csv\"\n",
    "# makedirs(exist_ok=True) creates the directory (and any missing parents)\n",
    "# without a separate existence check, so re-running the cell is safe.\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "def save_to_csv(output_dir, data, name_prefix,\n",
    "                header=None, n_parts=10):\n",
    "    \"\"\"Write `data` to `n_parts` CSV shard files and return their paths.\n",
    "\n",
    "    Args:\n",
    "        output_dir: directory that receives the shard files.\n",
    "        data: row-indexable 2-D array; each row becomes one CSV line.\n",
    "        name_prefix: shards are named \"<name_prefix>_<NN>.csv\".\n",
    "        header: optional header line written at the top of every shard.\n",
    "        n_parts: number of shards; rows are divided with np.array_split.\n",
    "\n",
    "    Returns:\n",
    "        The list of shard paths, in shard order.\n",
    "    \"\"\"\n",
    "    path_format = os.path.join(output_dir, \"{}_{:02d}.csv\")\n",
    "    filenames = []\n",
    "\n",
    "    def format_field(value):\n",
    "        # repr() of a NumPy scalar is e.g. \"np.float64(1.5)\" on NumPy >= 2,\n",
    "        # which is not a parseable CSV number; str() stays plain (\"1.5\").\n",
    "        if isinstance(value, (np.number, np.bool_)):\n",
    "            return str(value)\n",
    "        return repr(value)\n",
    "\n",
    "    for file_idx, row_indices in enumerate(\n",
    "            np.array_split(np.arange(len(data)), n_parts)):\n",
    "        part_csv = path_format.format(name_prefix, file_idx)\n",
    "        filenames.append(part_csv)\n",
    "        with open(part_csv, \"wt\", encoding=\"utf-8\") as f:\n",
    "            if header is not None:\n",
    "                f.write(header + \"\\n\")\n",
    "            for row_index in row_indices:\n",
    "                f.write(\",\".join(\n",
    "                    format_field(col) for col in data[row_index]))\n",
    "                f.write(\"\\n\")\n",
    "    return filenames\n",
    "\n",
    "# Append the target as the last column so each CSV row is features + label.\n",
    "train_data = np.c_[x_train_scaled, y_train]\n",
    "valid_data = np.c_[x_valid_scaled, y_valid]\n",
    "test_data = np.c_[x_test_scaled, y_test]\n",
    "# Header row written at the top of every shard (fixed typo: \"Midian\").\n",
    "header_cols = housing.feature_names + [\"MedianHouseValue\"]\n",
    "header_str = \",\".join(header_cols)\n",
    "\n",
    "train_filenames = save_to_csv(output_dir, train_data, \"train\",\n",
    "                              header_str, n_parts=20)\n",
    "valid_filenames = save_to_csv(output_dir, valid_data, \"valid\",\n",
    "                              header_str, n_parts=10)\n",
    "test_filenames = save_to_csv(output_dir, test_data, \"test\",\n",
    "                             header_str, n_parts=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train filenames:\n",
      "['generate_csv/train_00.csv',\n",
      " 'generate_csv/train_01.csv',\n",
      " 'generate_csv/train_02.csv',\n",
      " 'generate_csv/train_03.csv',\n",
      " 'generate_csv/train_04.csv',\n",
      " 'generate_csv/train_05.csv',\n",
      " 'generate_csv/train_06.csv',\n",
      " 'generate_csv/train_07.csv',\n",
      " 'generate_csv/train_08.csv',\n",
      " 'generate_csv/train_09.csv',\n",
      " 'generate_csv/train_10.csv',\n",
      " 'generate_csv/train_11.csv',\n",
      " 'generate_csv/train_12.csv',\n",
      " 'generate_csv/train_13.csv',\n",
      " 'generate_csv/train_14.csv',\n",
      " 'generate_csv/train_15.csv',\n",
      " 'generate_csv/train_16.csv',\n",
      " 'generate_csv/train_17.csv',\n",
      " 'generate_csv/train_18.csv',\n",
      " 'generate_csv/train_19.csv']\n",
      "valid filenames:\n",
      "['generate_csv/valid_00.csv',\n",
      " 'generate_csv/valid_01.csv',\n",
      " 'generate_csv/valid_02.csv',\n",
      " 'generate_csv/valid_03.csv',\n",
      " 'generate_csv/valid_04.csv',\n",
      " 'generate_csv/valid_05.csv',\n",
      " 'generate_csv/valid_06.csv',\n",
      " 'generate_csv/valid_07.csv',\n",
      " 'generate_csv/valid_08.csv',\n",
      " 'generate_csv/valid_09.csv']\n",
      "test filenames:\n",
      "['generate_csv/test_00.csv',\n",
      " 'generate_csv/test_01.csv',\n",
      " 'generate_csv/test_02.csv',\n",
      " 'generate_csv/test_03.csv',\n",
      " 'generate_csv/test_04.csv',\n",
      " 'generate_csv/test_05.csv',\n",
      " 'generate_csv/test_06.csv',\n",
      " 'generate_csv/test_07.csv',\n",
      " 'generate_csv/test_08.csv',\n",
      " 'generate_csv/test_09.csv']\n"
     ]
    }
   ],
   "source": [
    "import pprint\n",
    "\n",
    "# Show the shard paths generated for each split.\n",
    "for label, filenames in ((\"train\", train_filenames),\n",
    "                         (\"valid\", valid_filenames),\n",
    "                         (\"test\", test_filenames)):\n",
    "    print(label + \" filenames:\")\n",
    "    pprint.pprint(filenames)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(b'generate_csv/train_19.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv/train_06.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv/train_04.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv/train_03.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv/train_17.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv/train_18.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv/train_15.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv/train_00.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv/train_10.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv/train_11.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv/train_01.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv/train_05.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv/train_14.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv/train_16.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv/train_09.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv/train_02.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv/train_07.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv/train_08.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv/train_12.csv', shape=(), dtype=string)\n",
      "tf.Tensor(b'generate_csv/train_13.csv', shape=(), dtype=string)\n"
     ]
    }
   ],
   "source": [
    "# Reading pipeline built over the next cells:\n",
    "# 1. filenames -> dataset of filenames\n",
    "# 2. each file -> dataset of text lines; the per-file datasets are merged\n",
    "# 3. each text line -> parsed CSV fields\n",
    "\n",
    "# Step 1: one string tensor per shard path (see the printed output).\n",
    "filename_dataset = tf.data.Dataset.list_files(train_filenames)\n",
    "for filename in filename_dataset:\n",
    "    print(filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "b'0.4853051504718848,-0.8492418886278699,-0.06530126513877861,-0.023379656040017353,1.4974350551260218,-0.07790657783453239,-0.9023632702857819,0.7814514907892068,2.956'\n",
      "b'-0.32652634129448693,0.43236189741438374,-0.09345459539684739,-0.08402991822890092,0.8460035745154013,-0.0266316482653991,-0.5617679242614233,0.1422875991184281,2.431'\n",
      "b'-1.0591781535672364,1.393564736946074,-0.026331968874673636,-0.11006759528831847,-0.6138198966579805,-0.09695934953589447,0.3247131133362288,-0.037477245413977976,0.672'\n",
      "b'0.15782311132800697,0.43236189741438374,0.3379948076652917,-0.015880306122244434,-0.3733890577139493,-0.05305245634489608,0.8006134598360177,-1.2359095422966828,3.169'\n",
      "b'0.04326300977263167,-1.0895425985107923,-0.38878716774583305,-0.10789864528874438,-0.6818663605100649,-0.0723871014747467,-0.8883662012710817,0.8213992340186296,1.426'\n",
      "b'-0.7432054083470616,0.9129633171802288,-0.644320243857189,-0.1479096959813185,0.7398510909061499,0.11427691039226903,-0.7950524078397521,0.6815821327156534,1.438'\n",
      "b'2.2754266257529974,-1.249743071766074,1.0294788075585177,-0.17124431895714504,-0.45413752815175606,0.10527151658164971,-0.9023632702857819,0.9012947204774823,3.798'\n",
      "b'-0.2223565745313433,1.393564736946074,0.02991299565857307,0.0801452044790158,-0.509481985418118,-0.06238599304952824,-0.86503775291325,0.8613469772480595,2.0'\n",
      "b'2.2878417437355094,-1.8905449647872008,0.6607106467795992,-0.14964778023694128,-0.06672632728722275,0.44788055801575993,-0.5337737862320228,0.5667323709310584,3.59'\n",
      "b'-0.7543417158936074,-0.9293421252555106,-0.9212720434835953,0.1242806741969112,-0.5983960315181748,-0.18494335623235414,-0.8183808561975836,0.8513600414406984,1.717'\n",
      "b'-0.8246762898717912,-0.04823952235146133,-0.3448658166118309,-0.08477587145199328,0.5012348243315076,-0.034699996532417135,0.5300034588851571,-0.08741192445075467,0.717'\n",
      "b'-0.4394346460367383,0.1920611875314612,-0.39172440230167493,-0.06233787211356993,0.682692061270399,-0.012080008421921133,0.935918460311448,-1.2458964781040367,1.618'\n",
      "b'0.199384450496934,1.0731637904355105,-0.19840853933562783,-0.29328906965393414,-0.07852104768825069,0.018804888420646343,0.8006134598360177,-1.1510205879341566,1.99'\n",
      "b'0.3798565732727743,-1.5701440182766375,0.4541195259524651,-0.13374802152613807,-0.28356772542919806,-0.04747003172530946,-0.3191520613399599,-0.41698080609349797,1.901'\n",
      "b'0.18702261628258646,-0.20843999560674303,0.005869659830725365,-0.2645340092721605,-0.011381870020860852,-0.015878889894211247,0.05876880205693385,0.17224840654049697,0.84'\n"
     ]
    }
   ],
   "source": [
    "n_readers = 5\n",
    "\n",
    "\n",
    "def read_lines_without_header(filename):\n",
    "    \"\"\"Return the text lines of one CSV shard, skipping its first line.\"\"\"\n",
    "    return tf.data.TextLineDataset(filename).skip(1)\n",
    "\n",
    "\n",
    "# Step 2: read n_readers files at a time and interleave their lines.\n",
    "dataset = filename_dataset.interleave(\n",
    "    read_lines_without_header,\n",
    "    cycle_length=n_readers\n",
    ")\n",
    "for line in dataset.take(15):\n",
    "    print(line.numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[<tf.Tensor: id=124, shape=(), dtype=int32, numpy=1>, <tf.Tensor: id=125, shape=(), dtype=int32, numpy=2>, <tf.Tensor: id=126, shape=(), dtype=float32, numpy=3.0>, <tf.Tensor: id=127, shape=(), dtype=string, numpy=b'4'>, <tf.Tensor: id=128, shape=(), dtype=float32, numpy=5.0>]\n"
     ]
    }
   ],
   "source": [
    "# tf.io.decode_csv(str, record_defaults)\n",
    "# Each entry of record_defaults supplies the default for one CSV field; the\n",
    "# recorded output shows the parsed dtypes following these defaults\n",
    "# (int32, int32, float32, string, float32).\n",
    "\n",
    "sample_str = '1,2,3,4,5'\n",
    "record_defaults = [\n",
    "    tf.constant(0, dtype=tf.int32),\n",
    "    0,\n",
    "    np.nan,\n",
    "    \"hello\",\n",
    "    tf.constant([])  # empty default: the next cell shows this field is required\n",
    "]\n",
    "parsed_fields = tf.io.decode_csv(sample_str, record_defaults)\n",
    "print(parsed_fields)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Field 4 is required but missing in record 0! [Op:DecodeCSV]\n"
     ]
    }
   ],
   "source": [
    "# Every field is empty here. Field 4 has an empty default (tf.constant([])),\n",
    "# so decoding fails with the \"required but missing\" error printed below.\n",
    "try:\n",
    "    parsed_fields = tf.io.decode_csv(',,,,', record_defaults)\n",
    "except tf.errors.InvalidArgumentError as ex:\n",
    "    print(ex)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Expect 5 fields but have 7 in record 0 [Op:DecodeCSV]\n"
     ]
    }
   ],
   "source": [
    "# Seven values against five record_defaults: the field counts do not match,\n",
    "# so decoding fails with the error printed below.\n",
    "try:\n",
    "    parsed_fields = tf.io.decode_csv('1,2,3,4,5,6,7', record_defaults)\n",
    "except tf.errors.InvalidArgumentError as ex:\n",
    "    print(ex)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(<tf.Tensor: id=153, shape=(8,), dtype=float32, numpy=\n",
       " array([-0.9868721 ,  0.8328631 , -0.18684709, -0.1488895 , -0.45323023,\n",
       "        -0.11504996,  1.6730974 , -0.74654967], dtype=float32)>,\n",
       " <tf.Tensor: id=154, shape=(1,), dtype=float32, numpy=array([1.138], dtype=float32)>)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def parse_csv_line(line, n_fields = 9):\n",
    "    \"\"\"Decode one CSV line into a (features, target) pair of tensors.\n",
    "\n",
    "    Every field is decoded with a NaN float default. All fields except the\n",
    "    last are stacked into `x`; the last field alone is stacked into `y`.\n",
    "    \"\"\"\n",
    "    record_defaults = [tf.constant(np.nan) for _ in range(n_fields)]\n",
    "    *feature_fields, target_field = tf.io.decode_csv(\n",
    "        line, record_defaults=record_defaults)\n",
    "    features = tf.stack(feature_fields)\n",
    "    target = tf.stack([target_field])\n",
    "    return features, target\n",
    "\n",
    "# Sanity check on a single line: the recorded output is an 8-element feature\n",
    "# tensor and a 1-element target tensor.\n",
    "parse_csv_line(b'-0.9868720801669367,0.832863080552588,-0.18684708416901633,-0.14888949288707784,-0.4532302419670616,-0.11504995754593579,1.6730974284189664,-0.7465496877362412,1.138',\n",
    "               n_fields=9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x:\n",
      "<tf.Tensor: id=229, shape=(3, 8), dtype=float32, numpy=\n",
      "array([[ 0.09734604,  0.75276285, -0.20218964, -0.19547   , -0.40605137,\n",
      "         0.00678553, -0.81371516,  0.6566148 ],\n",
      "       [-0.32652634,  0.4323619 , -0.09345459, -0.08402992,  0.8460036 ,\n",
      "        -0.02663165, -0.56176794,  0.1422876 ],\n",
      "       [ 0.4240821 ,  0.91296333, -0.04437482, -0.15297213, -0.24727628,\n",
      "        -0.10539167,  0.86126745, -1.335779  ]], dtype=float32)>\n",
      "y:\n",
      "<tf.Tensor: id=230, shape=(3, 1), dtype=float32, numpy=\n",
      "array([[1.119],\n",
      "       [2.431],\n",
      "       [3.955]], dtype=float32)>\n",
      "x:\n",
      "<tf.Tensor: id=233, shape=(3, 8), dtype=float32, numpy=\n",
      "array([[ 0.48530516, -0.8492419 , -0.06530126, -0.02337966,  1.4974351 ,\n",
      "        -0.07790658, -0.90236324,  0.78145146],\n",
      "       [ 0.63034356,  1.8741661 , -0.06713215, -0.12543367, -0.19737554,\n",
      "        -0.02272263, -0.69240725,  0.72652334],\n",
      "       [-1.4803331 , -0.68904144, -0.35624704, -0.17255889, -0.82158846,\n",
      "        -0.13823092,  1.9157133 , -1.0211904 ]], dtype=float32)>\n",
      "y:\n",
      "<tf.Tensor: id=234, shape=(3, 1), dtype=float32, numpy=\n",
      "array([[2.956],\n",
      "       [2.419],\n",
      "       [0.928]], dtype=float32)>\n"
     ]
    }
   ],
   "source": [
    "# 1. filename -> dataset\n",
    "# 2. read file -> dataset -> datasets -> merge\n",
    "# 3. parse csv\n",
    "def csv_reader_dataset(filenames, n_readers=5,\n",
    "                       batch_size=32, n_parse_threads=5,\n",
    "                       shuffle_buffer_size=10000):\n",
    "    \"\"\"Build a repeating, shuffled, batched dataset from CSV shard files.\n",
    "\n",
    "    Args:\n",
    "        filenames: CSV shard paths; the first line of each file is skipped.\n",
    "        n_readers: number of files interleaved at a time (cycle_length).\n",
    "        batch_size: number of parsed rows per batch.\n",
    "        n_parse_threads: num_parallel_calls used when parsing lines.\n",
    "        shuffle_buffer_size: buffer size passed to Dataset.shuffle.\n",
    "\n",
    "    Returns:\n",
    "        A dataset of (x, y) batches as produced by parse_csv_line. The\n",
    "        filename dataset is repeated, so callers must bound iteration\n",
    "        themselves (e.g. take() or a steps argument).\n",
    "    \"\"\"\n",
    "    dataset = tf.data.Dataset.list_files(filenames)\n",
    "    dataset = dataset.repeat()\n",
    "    dataset = dataset.interleave(\n",
    "        lambda filename: tf.data.TextLineDataset(filename).skip(1),\n",
    "        cycle_length=n_readers\n",
    "    )\n",
    "    # Dataset transformations return a new dataset instead of mutating in\n",
    "    # place; without the re-assignment the shuffle is silently dropped.\n",
    "    dataset = dataset.shuffle(shuffle_buffer_size)\n",
    "    dataset = dataset.map(parse_csv_line,\n",
    "                          num_parallel_calls=n_parse_threads)\n",
    "    dataset = dataset.batch(batch_size)\n",
    "    return dataset\n",
    "\n",
    "# Preview two small batches to check the shapes of x and y.\n",
    "train_set = csv_reader_dataset(train_filenames, batch_size=3)\n",
    "for x_batch, y_batch in train_set.take(2):\n",
    "    for label, batch in ((\"x:\", x_batch), (\"y:\", y_batch)):\n",
    "        print(label)\n",
    "        pprint.pprint(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 32\n",
    "# Build the three input pipelines with the same batch size.\n",
    "train_set, valid_set, test_set = (\n",
    "    csv_reader_dataset(split_filenames, batch_size=batch_size)\n",
    "    for split_filenames in (train_filenames,\n",
    "                            valid_filenames,\n",
    "                            test_filenames)\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/100\n",
      "348/348 [==============================] - 2s 6ms/step - loss: 2.6241 - val_loss: 1.1120\n",
      "Epoch 2/100\n",
      "348/348 [==============================] - 2s 6ms/step - loss: 0.8670 - val_loss: 0.8067\n",
      "Epoch 3/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.7344 - val_loss: 0.7604\n",
      "Epoch 4/100\n",
      "348/348 [==============================] - 2s 5ms/step - loss: 0.6742 - val_loss: 0.7299\n",
      "Epoch 5/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.6572 - val_loss: 0.6977\n",
      "Epoch 6/100\n",
      "348/348 [==============================] - 2s 5ms/step - loss: 0.6285 - val_loss: 0.6710\n",
      "Epoch 7/100\n",
      "348/348 [==============================] - 2s 5ms/step - loss: 0.6049 - val_loss: 0.6499\n",
      "Epoch 8/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.5813 - val_loss: 0.6281\n",
      "Epoch 9/100\n",
      "348/348 [==============================] - 2s 5ms/step - loss: 0.5625 - val_loss: 0.6099\n",
      "Epoch 10/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.5458 - val_loss: 0.5940\n",
      "Epoch 11/100\n",
      "348/348 [==============================] - 2s 4ms/step - loss: 0.5371 - val_loss: 0.5806\n",
      "Epoch 12/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.5378 - val_loss: 0.5679\n",
      "Epoch 13/100\n",
      "348/348 [==============================] - 1s 3ms/step - loss: 0.5019 - val_loss: 0.5596\n",
      "Epoch 14/100\n",
      "348/348 [==============================] - 1s 3ms/step - loss: 0.5163 - val_loss: 0.5459\n",
      "Epoch 15/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.4920 - val_loss: 0.5392\n",
      "Epoch 16/100\n",
      "348/348 [==============================] - 2s 5ms/step - loss: 0.4869 - val_loss: 0.5332\n",
      "Epoch 17/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.4916 - val_loss: 0.5275\n",
      "Epoch 18/100\n",
      "348/348 [==============================] - 2s 4ms/step - loss: 0.4789 - val_loss: 0.5205\n",
      "Epoch 19/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.4913 - val_loss: 0.5207\n",
      "Epoch 20/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.4698 - val_loss: 0.5092\n",
      "Epoch 21/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.4691 - val_loss: 0.5055\n",
      "Epoch 22/100\n",
      "348/348 [==============================] - 2s 7ms/step - loss: 0.4668 - val_loss: 0.5010\n",
      "Epoch 23/100\n",
      "348/348 [==============================] - 3s 8ms/step - loss: 0.4623 - val_loss: 0.4972\n",
      "Epoch 24/100\n",
      "348/348 [==============================] - 2s 6ms/step - loss: 0.4665 - val_loss: 0.4945\n",
      "Epoch 25/100\n",
      "348/348 [==============================] - 2s 6ms/step - loss: 0.4605 - val_loss: 0.4967\n",
      "Epoch 26/100\n",
      "348/348 [==============================] - 2s 5ms/step - loss: 0.4562 - val_loss: 0.4880\n",
      "Epoch 27/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.4534 - val_loss: 0.4830\n",
      "Epoch 28/100\n",
      "348/348 [==============================] - 2s 5ms/step - loss: 0.4467 - val_loss: 0.4856\n",
      "Epoch 29/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.4513 - val_loss: 0.4831\n",
      "Epoch 30/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.4477 - val_loss: 0.4769\n",
      "Epoch 31/100\n",
      "348/348 [==============================] - 1s 3ms/step - loss: 0.4456 - val_loss: 0.4745\n",
      "Epoch 32/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.4395 - val_loss: 0.4724\n",
      "Epoch 33/100\n",
      "348/348 [==============================] - 2s 5ms/step - loss: 0.4529 - val_loss: 0.4739\n",
      "Epoch 34/100\n",
      "348/348 [==============================] - 2s 4ms/step - loss: 0.4327 - val_loss: 0.4740\n",
      "Epoch 35/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.4381 - val_loss: 0.4702\n",
      "Epoch 36/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.4454 - val_loss: 0.4644\n",
      "Epoch 37/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.4303 - val_loss: 0.4674\n"
     ]
    }
   ],
   "source": [
    "# Small regression network: 8 input features -> 30 ReLU units -> 1 output.\n",
    "model = keras.models.Sequential([\n",
    "    keras.layers.Dense(30, activation='relu',\n",
    "                       input_shape=[8]),\n",
    "    keras.layers.Dense(1),\n",
    "])\n",
    "model.compile(loss=\"mean_squared_error\", optimizer=\"sgd\")\n",
    "callbacks = [keras.callbacks.EarlyStopping(\n",
    "    patience=5, min_delta=1e-2)]\n",
    "\n",
    "# The datasets repeat indefinitely, so Keras needs explicit step counts.\n",
    "# Derive them from the actual split sizes instead of hand-typed literals\n",
    "# (the training split has 11610 rows, not 11160).\n",
    "history = model.fit(train_set,\n",
    "                    validation_data=valid_set,\n",
    "                    steps_per_epoch=len(x_train) // batch_size,\n",
    "                    validation_steps=len(x_valid) // batch_size,\n",
    "                    epochs=100,\n",
    "                    callbacks=callbacks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "161/161 [==============================] - 0s 3ms/step - loss: 0.4601\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.46014384971642347"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# test_set repeats indefinitely, so bound evaluation to one pass of batches.\n",
    "model.evaluate(test_set, steps=len(x_test) // batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
